#  Version: 2020.02.21
#
#  MIT License
#
#  Copyright (c) 2018 Jiankang Deng and Jia Guo
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in all
#  copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#  SOFTWARE.
#

# THIS FILE IS FOR EXPERIMENTS, USE train_softmax.py FOR NORMAL TRAINING.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import math
import random
import logging
import pickle
import numpy as np
from data import FaceImageIter
from data import FaceImageIterList
import mxnet as mx
from mxnet import ndarray as nd
import argparse
import mxnet.optimizer as optimizer

sys.path.append(os.path.join(os.path.dirname(__file__), 'common'))
import face_image
from noise_sgd import NoiseSGD

sys.path.append(os.path.join(os.path.dirname(__file__), 'eval'))
sys.path.append(os.path.join(os.path.dirname(__file__), 'symbols'))
import fresnet
import finception_resnet_v2
import fmobilenet
import fmobilenetv2
import fxception
import fdensenet
import fdpn
import fnasnet
import spherenet
# import lfw
import verification
import sklearn

sys.path.append(os.path.join(os.path.dirname(__file__), 'losses'))
import center_loss

# Root logger with its threshold set to INFO.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Module-level option namespace. AccMetric.update reads it as a global
# (loss_type, ctx_num, margin_verbose, logits_verbose); it starts as None
# here -- presumably the script entry point assigns the parsed arguments
# before training starts (not visible in this part of the file).
args = None


class AccMetric(mx.metric.EvalMetric):
    """Top-1 accuracy metric that can also print margin / logit diagnostics.

    Accuracy is computed from ``preds[1]`` (the softmax output).  Depending on
    the module-level ``args`` namespace, ``update`` additionally prints:

    * ``[<batch>][MARGIN]a,b`` when ``2 <= args.loss_type <= 7`` and
      ``args.margin_verbose > 0``;
    * ``[LOGITS] <batch>,a,b`` when ``args.logits_verbose > 0``.

    Both reports are emitted on the first mini-batch and then every
    ``*_verbose`` mini-batches, where a mini-batch is counted as
    ``args.ctx_num`` calls to ``update``.
    """

    def __init__(self):
        # axis must exist before the base initialiser runs because it is
        # forwarded to it as a keyword argument.
        self.axis = 1
        super(AccMetric, self).__init__(
            'acc', axis=self.axis,
            output_names=None, label_names=None)
        self.losses = []
        # Number of update() calls seen so far.
        self.count = 0

    def _report_batch(self, every):
        """Return the mini-batch index if a diagnostic line is due, else None.

        A line is due only on every ``args.ctx_num``-th call, and then only
        for mini-batch 1 or mini-batches that are a multiple of ``every``.
        """
        if self.count % args.ctx_num != 0:
            return None
        mbatch = self.count // args.ctx_num
        if mbatch == 1 or mbatch % every == 0:
            return mbatch
        return None

    def update(self, labels, preds):
        """Accumulate accuracy for one call and print diagnostics when due.

        Args:
            labels: sequence of label NDArrays.
            preds: sequence of network output NDArrays; index 1 is used as
                the softmax output, trailing entries as diagnostic values.
        """
        self.count += 1

        if 2 <= args.loss_type <= 7 and args.margin_verbose > 0:
            mbatch = self._report_batch(args.margin_verbose)
            if mbatch is not None:
                a = 0.0
                b = 0.0
                if len(preds) >= 4:
                    # Last two outputs: value before and after the margin.
                    a = preds[-2].asnumpy()[0]
                    b = preds[-1].asnumpy()[0]
                elif len(preds) == 3:
                    # Only one diagnostic output is present; report it twice.
                    a = preds[-1].asnumpy()[0]
                    b = a
                print('[%d][MARGIN]%f,%f' % (mbatch, a, b))

        if args.logits_verbose > 0:
            mbatch = self._report_batch(args.logits_verbose)
            if mbatch is not None:
                a = 0.0
                b = 0.0
                if len(preds) >= 3:
                    v = np.sort(preds[-1].asnumpy())
                    # Mean of the lowest and highest tenth of the values.
                    # Clamp to one element: with fewer than 10 values the
                    # slice size would be 0, making the low mean NaN and
                    # turning v[-0:] into the whole array.
                    num = max(len(v) // 10, 1)
                    a = np.mean(v[0:num])
                    b = np.mean(v[-1 * num:])
                print('[LOGITS] %d,%f,%f' % (mbatch, a, b))

        preds = [preds[1]]  # use softmax output
        for label, pred_label in zip(labels, preds):
            # Class scores -> class index when the shapes do not already match.
            if pred_label.shape != label.shape:
                pred_label = mx.ndarray.argmax(pred_label, axis=self.axis)
            pred_label = pred_label.asnumpy().astype('int32').flatten()
            label = label.asnumpy()
            if label.ndim == 2:
                # Multi-column label: only the first column is compared.
                label = label[:, 0]
            label = label.astype('int32').flatten()
            assert label.shape == pred_label.shape
            self.sum_metric += (pred_label == label).sum()
            self.num_inst += pred_label.size


class LossValueMetric(mx.metric.EvalMetric):
    """Running average of the scalar loss found in the last network output."""

    def __init__(self):
        # axis must exist before the base initialiser runs because it is
        # forwarded to it as a keyword argument.
        self.axis = 1
        super(LossValueMetric, self).__init__(
            'lossvalue', axis=self.axis,
            output_names=None, label_names=None)
        self.losses = []

    def update(self, labels, preds):
        """Add the first element of ``preds[-1]`` to the running sum.

        Args:
            labels: unused.
            preds: sequence of network output NDArrays; the last one is read
                as the loss value.
        """
        # The original also fetched preds[-2] into an unused local; that extra
        # device-to-host copy (and the IndexError it raised when only one
        # output exists) has been dropped.
        loss = preds[-1].asnumpy()[0]
        self.sum_metric += loss
        self.num_inst += 1.0


def parse_args():
    """Parse ``sys.argv`` and return the training options namespace.

    Options are declared in a table of ``(flag, type, default, help)`` rows
    and registered in that order; a ``None`` type leaves the value as the raw
    string, which is argparse's default behaviour.
    """
    option_table = (
        # general
        ('--data-dir', None, '', 'training set directory'),
        ('--prefix', None, '../model/model', 'directory to save model.'),
        ('--pretrained', None, '', 'pretrained model to load'),
        ('--ckpt', int, 1,
         'checkpoint saving option. 0: discard saving. 1: save when necessary. 2: always save'),
        ('--network', None, 'r50', 'specify network'),
        ('--version-se', int, 0, 'whether to use se in network'),
        ('--version-input', int, 1, 'network input config'),
        ('--version-output', str, 'E', 'network embedding output config'),
        ('--version-unit', int, 3, 'resnet unit config'),
        ('--version-act', str, 'prelu', 'network activation config'),
        ('--end-epoch', int, 100000, 'training epoch size.'),
        ('--noise-sgd', float, 0.0, ''),
        ('--lr', float, 0.1, 'start learning rate'),
        ('--wd', float, 0.0005, 'weight decay'),
        ('--mom', float, 0.9, 'momentum'),
        ('--emb-size', int, 512, 'embedding length'),
        ('--per-batch-size', int, 128, 'batch size in each context'),
        ('--margin-m', float, 0.5, ''),
        ('--margin-s', float, 64.0, ''),
        ('--margin-a', float, 0.0, ''),
        ('--margin-b', float, 0.0, ''),
        ('--easy-margin', int, 0, ''),
        ('--margin-verbose', int, 0, ''),
        ('--logits-verbose', int, 0, ''),
        ('--c2c-threshold', float, 0.0, ''),
        ('--c2c-mode', int, -10, ''),
        ('--output-c2c', int, 0, ''),
        ('--train-limit', int, 0, ''),
        ('--margin', int, 4, ''),
        ('--beta', float, 1000., ''),
        ('--beta-min', float, 5., ''),
        ('--beta-freeze', int, 0, ''),
        ('--gamma', float, 0.12, ''),
        ('--power', float, 1.0, ''),
        ('--scale', float, 0.9993, ''),
        ('--center-alpha', float, 0.5, ''),
        ('--center-scale', float, 0.003, ''),
        ('--images-per-identity', int, 0, ''),
        ('--triplet-bag-size', int, 3600, ''),
        ('--triplet-alpha', float, 0.3, ''),
        ('--triplet-max-ap', float, 0.0, ''),
        ('--verbose', int, 2000, ''),
        ('--loss-type', int, 4, ''),
        ('--incay', float, 0.0, 'feature incay'),
        ('--use-deformable', int, 0, ''),
        ('--rand-mirror', int, 1, ''),
        ('--cutoff', int, 0, ''),
        ('--patch', str, '0_0_96_112_0', ''),
        ('--lr-steps', str, '', ''),
        ('--max-steps', int, 0, ''),
        ('--target', str, 'lfw,cfp_fp,agedb_30,cplfw,calfw', ''),
    )

    cli = argparse.ArgumentParser(description='Train face network')
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()


def get_symbol(args, arg_params, aux_params):
    """Build the grouped training symbol: backbone embedding plus a loss head.

    The backbone is chosen by the first character of ``args.network``
    ('d' densenet, 'm' mobilenet / mobilenetv2, 'i' inception-resnet-v2,
    'x' xception, 'p' dpn, 'n' nasnet, 's' spherenet, anything else resnet).

    The loss head is chosen by ``args.loss_type`` (``t`` is the angle between
    the normalised embedding and the ground-truth class weight, ``s`` is
    ``args.margin_s``, ``m`` is ``args.margin_m``):

    * 0  -- plain softmax (fully connected layer with bias)
    * 1  -- LSoftmax on an L2-normalised weight ("sphere")
    * 2  -- additive cosine margin (``s*m`` subtracted from the target logit)
    * 3  -- multiplicative angular margin (target logit ``s*cos(m*t)``)
    * 4  -- additive angular margin (target logit ``s*cos(t+m)``)
    * 5  -- combined margin (target logit ``s*(cos(a*t+m)-b)``)
    * 6  -- additive angular margin scaled per sample from ``min(t)/t``
    * 7  -- additive angular margin sampled from uniform(margin_b, margin_m)
    * 8  -- softmax plus the custom ``centerloss`` op
    * 9  -- coco loss (logits against per-identity centroids)
    * 10 -- marginal loss
    * 11 -- n-pair loss
    * 12 -- triplet loss on squared L2 distances
    * 13 -- triplet loss on angles
    * any other value -- LSoftmax on a scaled embedding and weight

    Args:
        args: option namespace.  Besides the command-line options this reads
            ``num_layers``, ``num_classes``, ``per_batch_size`` and, for some
            loss types, ``per_identities`` / ``coco_scale`` -- these are
            expected to have been attached by the caller.
        arg_params: returned unchanged.
        aux_params: returned unchanged.

    Returns:
        ``(out, arg_params, aux_params)`` where ``out`` is an ``mx.symbol.Group``
        of, in order: the gradient-blocked embedding; the ``SoftmaxOutput``
        when ``loss_type < 10``, otherwise the gradient-blocked label; the
        extra loss if one was built; and any gradient-blocked diagnostic
        symbols requested through ``margin_verbose`` / ``logits_verbose``.
    """
    # Diagnostic symbols appended to the output group (gradient-blocked).
    margin_symbols = []

    # ------------------------------------------------------------------
    # Backbone
    # ------------------------------------------------------------------
    if args.network[0] == 'd':
        embedding = fdensenet.get_symbol(args.emb_size, args.num_layers,
                                         version_se=args.version_se, version_input=args.version_input,
                                         version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'm':
        print('init mobilenet', args.num_layers)
        if args.num_layers == 1:
            embedding = fmobilenet.get_symbol(args.emb_size,
                                              version_se=args.version_se, version_input=args.version_input,
                                              version_output=args.version_output, version_unit=args.version_unit)
        else:
            embedding = fmobilenetv2.get_symbol(args.emb_size)
    elif args.network[0] == 'i':
        print('init inception-resnet-v2', args.num_layers)
        embedding = finception_resnet_v2.get_symbol(args.emb_size,
                                                    version_se=args.version_se, version_input=args.version_input,
                                                    version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'x':
        print('init xception', args.num_layers)
        embedding = fxception.get_symbol(args.emb_size,
                                         version_se=args.version_se, version_input=args.version_input,
                                         version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'p':
        print('init dpn', args.num_layers)
        embedding = fdpn.get_symbol(args.emb_size, args.num_layers,
                                    version_se=args.version_se, version_input=args.version_input,
                                    version_output=args.version_output, version_unit=args.version_unit)
    elif args.network[0] == 'n':
        print('init nasnet', args.num_layers)
        embedding = fnasnet.get_symbol(args.emb_size)
    elif args.network[0] == 's':
        print('init spherenet', args.num_layers)
        embedding = spherenet.get_symbol(args.emb_size, args.num_layers)
    else:
        print('init resnet', args.num_layers)
        embedding = fresnet.get_symbol(args.emb_size, args.num_layers,
                                       version_se=args.version_se, version_input=args.version_input,
                                       version_output=args.version_output, version_unit=args.version_unit,
                                       version_act=args.version_act)

    # ------------------------------------------------------------------
    # Labels
    # ------------------------------------------------------------------
    all_label = mx.symbol.Variable('softmax_label')
    if not args.output_c2c:
        gt_label = all_label
    else:
        # Two-column label: column 0 is the class id, column 1 the c2c value.
        gt_label = mx.symbol.slice_axis(all_label, axis=1, begin=0, end=1)
        gt_label = mx.symbol.reshape(gt_label, (args.per_batch_size,))
        c2c_label = mx.symbol.slice_axis(all_label, axis=1, begin=1, end=2)
        c2c_label = mx.symbol.reshape(c2c_label, (args.per_batch_size,))

    # ------------------------------------------------------------------
    # Loss head
    # ------------------------------------------------------------------
    assert args.loss_type >= 0
    extra_loss = None
    if args.loss_type == 0:  # softmax
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
    elif args.loss_type == 1:  # sphere
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta, margin=args.margin, scale=args.scale,
                              beta_min=args.beta_min, verbose=1000, name='fc7')
    elif args.loss_type == 8:  # centerloss, TODO
        _weight = mx.symbol.Variable('fc7_weight')
        _bias = mx.symbol.Variable('fc7_bias', lr_mult=2.0, wd_mult=0.0)
        fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, bias=_bias, num_hidden=args.num_classes, name='fc7')
        print('center-loss', args.center_alpha, args.center_scale)
        extra_loss = mx.symbol.Custom(data=embedding, label=gt_label, name='center_loss', op_type='centerloss',
                                      num_class=args.num_classes, alpha=args.center_alpha, scale=args.center_scale,
                                      batchsize=args.per_batch_size)
    elif args.loss_type == 2:  # additive cosine margin
        s = args.margin_s
        m = args.margin_m
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        if s > 0.0:
            # Normalised embedding scaled by s; the margin is s*m on the target logit.
            nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
            fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes,
                                        name='fc7')
            if m > 0.0:
                if args.margin_verbose > 0:
                    zy = mx.sym.pick(fc7, gt_label, axis=1)
                    cos_t = zy / s
                    margin_symbols.append(mx.symbol.mean(cos_t))

                s_m = s * m
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
                fc7 = fc7 - gt_one_hot

                if args.margin_verbose > 0:
                    new_zy = mx.sym.pick(fc7, gt_label, axis=1)
                    new_cos_t = new_zy / s
                    margin_symbols.append(mx.symbol.mean(new_cos_t))
        else:
            # Un-normalised embedding: the margin is m times the embedding's L2 norm.
            fc7 = mx.sym.FullyConnected(data=embedding, weight=_weight, no_bias=True, num_hidden=args.num_classes,
                                        name='fc7')
            if m > 0.0:
                body = embedding * embedding
                body = mx.sym.sum_axis(body, axis=1, keepdims=True)
                body = mx.sym.sqrt(body)
                body = body * m
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, body)
                fc7 = fc7 - body

    elif args.loss_type == 3:  # multiplicative angular margin
        s = args.margin_s
        m = args.margin_m
        assert args.margin == 2 or args.margin == 4
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes,
                                    name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        if args.margin_verbose > 0:
            margin_symbols.append(mx.symbol.mean(cos_t))
        if m > 1.0:
            t = mx.sym.arccos(cos_t)
            t = t * m
            body = mx.sym.cos(t)
            new_zy = body * s
            if args.margin_verbose > 0:
                new_cos_t = new_zy / s
                margin_symbols.append(mx.symbol.mean(new_cos_t))
            # Replace only the target-class logit: add (new - old) at the label position.
            diff = new_zy - zy
            diff = mx.sym.expand_dims(diff, 1)
            gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
            body = mx.sym.broadcast_mul(gt_one_hot, diff)
            fc7 = fc7 + body
    elif args.loss_type == 4:  # additive angular margin
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes,
                                    name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        if args.margin_verbose > 0:
            margin_symbols.append(mx.symbol.mean(cos_t))
        if args.output_c2c == 0:
            cos_m = math.cos(m)
            sin_m = math.sin(m)
            mm = math.sin(math.pi - m) * m
            threshold = math.cos(math.pi - m)
            # cond is non-zero where the margin formula is applied:
            # cos_t > 0 for the easy margin, cos_t > cos(pi - m) otherwise.
            if args.easy_margin:
                cond = mx.symbol.Activation(data=cos_t, act_type='relu')
            else:
                cond_v = cos_t - threshold
                cond = mx.symbol.Activation(data=cond_v, act_type='relu')
            # cos(t + m) = cos_t*cos_m - sin_t*sin_m
            body = cos_t * cos_t
            body = 1.0 - body
            sin_t = mx.sym.sqrt(body)
            new_zy = cos_t * cos_m
            b = sin_t * sin_m
            new_zy = new_zy - b
            new_zy = new_zy * s
            # Fallback logit used where cond is zero.
            if args.easy_margin:
                zy_keep = zy
            else:
                zy_keep = zy - s * mm
            new_zy = mx.sym.where(cond, new_zy, zy_keep)
        else:
            # set c2c as cosm^2 in data.py
            cos_m = mx.sym.sqrt(c2c_label)
            sin_m = 1.0 - c2c_label
            sin_m = mx.sym.sqrt(sin_m)
            body = cos_t * cos_t
            body = 1.0 - body
            sin_t = mx.sym.sqrt(body)
            new_zy = cos_t * cos_m
            b = sin_t * sin_m
            new_zy = new_zy - b
            new_zy = new_zy * s

        if args.margin_verbose > 0:
            new_cos_t = new_zy / s
            margin_symbols.append(mx.symbol.mean(new_cos_t))
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 5:  # combined margin: cos(a*t + m) - b
        s = args.margin_s
        assert s > 0.0
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes,
                                    name='fc7')
        # (a, m, b) == (1, 0, 0) means no margin at all: leave fc7 untouched.
        if args.margin_a != 1.0 or args.margin_m != 0.0 or args.margin_b != 0.0:
            if args.margin_a == 1.0 and args.margin_m == 0.0:
                # Only b is active: a plain subtraction on the target logit is enough.
                s_m = s * args.margin_b
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=s_m, off_value=0.0)
                fc7 = fc7 - gt_one_hot
            else:
                zy = mx.sym.pick(fc7, gt_label, axis=1)
                cos_t = zy / s
                t = mx.sym.arccos(cos_t)
                if args.margin_a != 1.0:
                    t = t * args.margin_a
                if args.margin_m > 0.0:
                    t = t + args.margin_m
                body = mx.sym.cos(t)
                if args.margin_b > 0.0:
                    body = body - args.margin_b
                new_zy = body * s
                diff = new_zy - zy
                diff = mx.sym.expand_dims(diff, 1)
                gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
                body = mx.sym.broadcast_mul(gt_one_hot, diff)
                fc7 = fc7 + body
    elif args.loss_type == 6:  # per-sample scaled additive angular margin
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes,
                                    name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        if args.margin_verbose > 0:
            margin_symbols.append(mx.symbol.mean(t))
        t_min = mx.sym.min(t)
        ta = mx.sym.broadcast_div(t_min, t)

        # r1 = max(ta, margin_a), written as relu(ta - a1) + a1.
        a1 = args.margin_a
        r1 = ta - a1
        r1 = mx.symbol.Activation(data=r1, act_type='relu')
        r1 = r1 + a1

        r2 = mx.symbol.zeros(shape=(args.per_batch_size,))

        # Samples whose angle exceeds 1.0 get a zero margin factor.
        cond = t - 1.0
        cond = mx.symbol.Activation(data=cond, act_type='relu')
        r = mx.sym.where(cond, r2, r1)
        var_m = r * m
        t = t + var_m
        body = mx.sym.cos(t)
        new_zy = body * s
        if args.margin_verbose > 0:
            margin_symbols.append(mx.symbol.mean(t))
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 7:  # randomly sampled additive angular margin
        s = args.margin_s
        m = args.margin_m
        assert s > 0.0
        assert m >= 0.0
        assert m < (math.pi / 2)
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance')
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * s
        fc7 = mx.sym.FullyConnected(data=nembedding, weight=_weight, no_bias=True, num_hidden=args.num_classes,
                                    name='fc7')
        zy = mx.sym.pick(fc7, gt_label, axis=1)
        cos_t = zy / s
        t = mx.sym.arccos(cos_t)
        if args.margin_verbose > 0:
            margin_symbols.append(mx.symbol.mean(t))
        # One shared margin value, broadcast over every sample.
        var_m = mx.sym.random.uniform(low=args.margin_b, high=args.margin_m, shape=(1,))
        t = mx.sym.broadcast_add(t, var_m)
        body = mx.sym.cos(t)
        new_zy = body * s
        if args.margin_verbose > 0:
            margin_symbols.append(mx.symbol.mean(t))
        diff = new_zy - zy
        diff = mx.sym.expand_dims(diff, 1)
        gt_one_hot = mx.sym.one_hot(gt_label, depth=args.num_classes, on_value=1.0, off_value=0.0)
        body = mx.sym.broadcast_mul(gt_one_hot, diff)
        fc7 = fc7 + body
    elif args.loss_type == 10:  # marginal loss
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        params = [1.2, 0.3, 1.0]
        n1 = mx.sym.expand_dims(nembedding, axis=1)  # N,1,C
        n2 = mx.sym.expand_dims(nembedding, axis=0)  # 1,N,C
        body = mx.sym.broadcast_sub(n1, n2)  # N,N,C
        body = body * body
        body = mx.sym.sum(body, axis=2)  # N,N
        body = body - params[0]
        # Pairwise mask supplied by the data iterator through the 'extra' input.
        mask = mx.sym.Variable('extra')
        body = body * mask
        body = body + params[1]
        body = mx.symbol.Activation(data=body, act_type='relu')
        body = mx.sym.sum(body)
        # Average over the N*N - N off-diagonal pairs.
        body = body / (args.per_batch_size * args.per_batch_size - args.per_batch_size)
        extra_loss = mx.symbol.MakeLoss(body, grad_scale=params[2])
    elif args.loss_type == 11:  # npair loss
        params = [0.9, 0.2]
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        nembedding = mx.sym.transpose(nembedding)
        nembedding = mx.symbol.reshape(nembedding, (args.emb_size, args.per_identities, args.images_per_identity))
        nembedding = mx.sym.transpose(nembedding, axes=(2, 1, 0))  # 2*id*512
        n1 = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=1)
        n2 = mx.symbol.slice_axis(nembedding, axis=0, begin=1, end=2)
        n1 = mx.symbol.reshape(n1, (args.per_identities, args.emb_size))
        n2 = mx.symbol.reshape(n2, (args.per_identities, args.emb_size))
        cosine_matrix = mx.symbol.dot(lhs=n1, rhs=n2, transpose_b=True)  # id*id, id=N of N-pair
        data_extra = mx.sym.Variable('extra')
        data_extra = mx.sym.slice_axis(data_extra, axis=0, begin=0, end=args.per_identities)
        mask = cosine_matrix * data_extra
        fii = mx.sym.sum_axis(mask, axis=1)
        fij_fii = mx.sym.broadcast_sub(cosine_matrix, fii)
        fij_fii = mx.sym.exp(fij_fii)
        row = mx.sym.sum_axis(fij_fii, axis=1)
        row = mx.sym.log(row)
        body = mx.sym.mean(row)
        extra_loss = mx.sym.MakeLoss(body)
    elif args.loss_type == 12:  # triplet loss
        # The batch is laid out as three equal thirds: anchors, positives, negatives.
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size // 3)
        positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size // 3,
                                        end=2 * args.per_batch_size // 3)
        negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2 * args.per_batch_size // 3, end=args.per_batch_size)
        ap = anchor - positive
        an = anchor - negative
        ap = ap * ap
        an = an * an
        ap = mx.symbol.sum(ap, axis=1, keepdims=1)  # (T,1)
        an = mx.symbol.sum(an, axis=1, keepdims=1)  # (T,1)
        triplet_loss = mx.symbol.Activation(data=(ap - an + args.triplet_alpha), act_type='relu')
        triplet_loss = mx.symbol.mean(triplet_loss)
        extra_loss = mx.symbol.MakeLoss(triplet_loss)
    elif args.loss_type == 13:  # triplet loss with angular margin
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')
        anchor = mx.symbol.slice_axis(nembedding, axis=0, begin=0, end=args.per_batch_size // 3)
        positive = mx.symbol.slice_axis(nembedding, axis=0, begin=args.per_batch_size // 3,
                                        end=2 * args.per_batch_size // 3)
        negative = mx.symbol.slice_axis(nembedding, axis=0, begin=2 * args.per_batch_size // 3, end=args.per_batch_size)
        # Dot products of unit vectors are cosines; compare the angles directly.
        ap = anchor * positive
        an = anchor * negative
        ap = mx.symbol.sum(ap, axis=1, keepdims=1)  # (T,1)
        an = mx.symbol.sum(an, axis=1, keepdims=1)  # (T,1)

        ap = mx.symbol.arccos(ap)
        an = mx.symbol.arccos(an)
        triplet_loss = mx.symbol.Activation(data=(ap - an + args.margin_m), act_type='relu')

        triplet_loss = mx.symbol.mean(triplet_loss)
        extra_loss = mx.symbol.MakeLoss(triplet_loss)
    elif args.loss_type == 9:  # coco loss
        # One L2-normalised mean embedding per identity; identities occupy
        # consecutive runs of images_per_identity rows in the batch.
        centroids = []
        # range (not xrange) so this branch also works on Python 3.
        for i in range(args.per_identities):
            xs = mx.symbol.slice_axis(embedding, axis=0, begin=i * args.images_per_identity,
                                      end=(i + 1) * args.images_per_identity)
            mean = mx.symbol.mean(xs, axis=0, keepdims=True)
            mean = mx.symbol.L2Normalization(mean, mode='instance')
            centroids.append(mean)
        centroids = mx.symbol.concat(*centroids, dim=0)
        nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n') * args.coco_scale
        fc7 = mx.symbol.dot(nembedding, centroids, transpose_b=True)  # (batchsize, per_identities)
    else:
        embedding = embedding * 5
        _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0)
        _weight = mx.symbol.L2Normalization(_weight, mode='instance') * 2
        fc7 = mx.sym.LSoftmax(data=embedding, label=gt_label, num_hidden=args.num_classes,
                              weight=_weight,
                              beta=args.beta, margin=args.margin, scale=args.scale,
                              beta_min=args.beta_min, verbose=100, name='fc7')

    # ------------------------------------------------------------------
    # Optional feature-incay term (softmax / sphere only)
    # ------------------------------------------------------------------
    if args.loss_type <= 1 and args.incay > 0.0:
        params = [1.e-10]
        # 1 for correctly classified samples, 0 otherwise.
        sel = mx.symbol.argmax(data=fc7, axis=1)
        sel = (sel == gt_label)
        # Squared embedding norm, with a small epsilon against division by zero.
        norm = embedding * embedding
        norm = mx.symbol.sum(norm, axis=1)
        norm = norm + params[0]
        feature_incay = sel / norm
        feature_incay = mx.symbol.mean(feature_incay) * args.incay
        extra_loss = mx.symbol.MakeLoss(feature_incay)

    # ------------------------------------------------------------------
    # Output group
    # ------------------------------------------------------------------
    out_list = [mx.symbol.BlockGrad(embedding)]
    softmax = None
    if args.loss_type < 10:
        softmax = mx.symbol.SoftmaxOutput(data=fc7, label=gt_label, name='softmax', normalization='valid')
        out_list.append(softmax)
        if args.logits_verbose > 0:
            # Softmax probability assigned to the ground-truth class.
            logits = mx.symbol.softmax(data=fc7)
            logits = mx.sym.pick(logits, gt_label, axis=1)
            margin_symbols.append(logits)
    if softmax is None:
        # Keep output index 1 occupied when there is no softmax head.
        out_list.append(mx.sym.BlockGrad(gt_label))
    if extra_loss is not None:
        out_list.append(extra_loss)
    for _sym in margin_symbols:
        _sym = mx.sym.BlockGrad(_sym)
        out_list.append(_sym)
    out = mx.symbol.Group(out_list)
    return (out, arg_params, aux_params)


def train_net(args):
    """Configure and run a complete training session described by ``args``.

    The function derives several fields and stores them back on ``args``
    (``ctx_num``, ``num_layers``, ``batch_size``, ``num_classes``,
    ``image_h``/``image_w``, ``coco_scale``, ``use_val``, ...), builds the
    network symbol through ``get_symbol``, creates the training/validation
    iterators and finally calls ``model.fit``.  While training, the batch-end
    callback decays the learning rate at ``lr_steps``, evaluates the
    verification sets every ``args.verbose`` batches, optionally saves
    checkpoints and writes the annealed beta value to the ``BETA``
    environment variable.

    Args:
        args: parsed command line namespace; it is mutated in place.

    Raises:
        AssertionError: if ``args.patch`` does not contain five ``_``-separated
            integers, if more than one data directory is given for a loss
            type other than 12/13, if the dataset reports no classes, or if
            ``images_per_identity`` is inconsistent with the loss type.
        SystemExit: raised from the batch callback once ``args.max_steps``
            (when positive) has been exceeded.
    """
    # One GPU context per entry of CUDA_VISIBLE_DEVICES.  A missing variable is
    # treated like an empty one so that CPU-only machines fall back to
    # mx.cpu() instead of failing with KeyError.
    cvd = os.environ.get('CUDA_VISIBLE_DEVICES', '').strip()
    ctx = []
    if len(cvd) > 0:
        ctx = [mx.gpu(i) for i in range(len(cvd.split(',')))]
    if len(ctx) == 0:
        ctx = [mx.cpu()]
        print('use cpu')
    else:
        print('gpu num:', len(ctx))

    prefix = args.prefix
    prefix_dir = os.path.dirname(prefix)
    # dirname() is '' for a prefix without a directory part; os.makedirs('')
    # would raise, so only create the directory when there is one to create.
    if prefix_dir and not os.path.exists(prefix_dir):
        os.makedirs(prefix_dir)
    end_epoch = args.end_epoch
    args.ctx_num = len(ctx)
    # The network name is one letter followed by the layer count.
    args.num_layers = int(args.network[1:])
    print('num_layers', args.num_layers)
    if args.per_batch_size == 0:
        args.per_batch_size = 256 if args.loss_type == 10 else 128
    args.batch_size = args.per_batch_size * args.ctx_num
    args.rescale_threshold = 0
    args.image_channel = 3
    # Only validated here; the parsed values are not used by this function.
    ppatch = [int(x) for x in args.patch.split('_')]
    assert len(ppatch) == 5

    os.environ['BETA'] = str(args.beta)
    data_dir_list = args.data_dir.split(',')
    if args.loss_type not in (12, 13):
        assert len(data_dir_list) == 1
    data_dir = data_dir_list[0]
    args.use_val = False
    prop = face_image.load_property(data_dir)
    args.num_classes = prop.num_classes
    image_size = prop.image_size
    args.image_h = image_size[0]
    args.image_w = image_size[1]
    print('image_size', image_size)

    assert (args.num_classes > 0)
    print('num_classes', args.num_classes)
    args.coco_scale = 0.5 * math.log(float(args.num_classes - 1)) + 3

    path_imgrec = os.path.join(data_dir, "train.rec")
    val_rec = os.path.join(data_dir, "val.rec")
    # Validation is only enabled when val.rec exists and loss_type < 10.
    if os.path.exists(val_rec) and args.loss_type < 10:
        args.use_val = True
    else:
        val_rec = None

    if args.loss_type == 1 and args.num_classes > 20000:
        args.beta_freeze = 5000
        args.gamma = 0.06

    if args.loss_type < 9:
        assert args.images_per_identity == 0
    else:
        if args.images_per_identity == 0:
            if args.loss_type == 11:
                args.images_per_identity = 2
            elif args.loss_type == 10 or args.loss_type == 9:
                args.images_per_identity = 16
            elif args.loss_type == 12 or args.loss_type == 13:
                args.images_per_identity = 5
                assert args.per_batch_size % 3 == 0
        assert args.images_per_identity >= 2
        args.per_identities = int(args.per_batch_size / args.images_per_identity)

    print('Called with argument:', args)

    data_shape = (args.image_channel, image_size[0], image_size[1])
    mean = None

    begin_epoch = 0
    base_lr = args.lr
    base_wd = args.wd
    base_mom = args.mom
    # Optionally start from a checkpoint given as "<prefix>,<epoch>".
    arg_params = None
    aux_params = None
    if len(args.pretrained) > 0:
        vec = args.pretrained.split(',')
        print('loading', vec)
        _, arg_params, aux_params = mx.model.load_checkpoint(vec[0], int(vec[1]))
    sym, arg_params, aux_params = get_symbol(args, arg_params, aux_params)

    data_extra = None
    hard_mining = False
    triplet_params = None
    coco_mode = False
    if args.loss_type == 10:
        hard_mining = True
        # (batch_size, per_batch_size) matrix filled with -1; for every
        # per-device slice the square blocks of images_per_identity rows and
        # columns along the diagonal are set to +1.
        data_extra = np.full((args.batch_size, args.per_batch_size), -1.0, dtype=np.float32)
        for c in range(0, args.batch_size, args.per_batch_size):
            for a in range(0, args.per_batch_size, args.images_per_identity):
                b = a + args.images_per_identity
                data_extra[(c + a):(c + b), a:b] = 1.0
    elif args.loss_type == 11:
        # For every per-device slice, row c+i gets a single 1.0 in column i.
        data_extra = np.zeros((args.batch_size, args.per_identities), dtype=np.float32)
        for c in range(0, args.batch_size, args.per_batch_size):
            for i in range(args.per_identities):
                data_extra[c + i][i] = 1.0
    elif args.loss_type == 12 or args.loss_type == 13:
        triplet_params = [args.triplet_bag_size, args.triplet_alpha, args.triplet_max_ap]
    elif args.loss_type == 9:
        coco_mode = True

    label_name = 'softmax_label'
    label_shape = (args.batch_size,)
    if args.output_c2c:
        label_shape = (args.batch_size, 2)
    if data_extra is None:
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
        )
    else:
        # The extra matrix is fed to the network as a second data input.
        model = mx.mod.Module(
            context=ctx,
            symbol=sym,
            data_names=('data', 'extra'),
            label_names=(label_name,),
        )

    if args.use_val:
        val_dataiter = FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=val_rec,
            shuffle=False,
            rand_mirror=False,
            mean=mean,
            ctx_num=args.ctx_num,
            data_extra=data_extra,
        )
    else:
        val_dataiter = None

    def _make_train_iter(rec_path):
        # Single place that defines the training iterator configuration so
        # the single- and multi-dataset paths cannot drift apart.
        return FaceImageIter(
            batch_size=args.batch_size,
            data_shape=data_shape,
            path_imgrec=rec_path,
            shuffle=True,
            rand_mirror=args.rand_mirror,
            mean=mean,
            cutoff=args.cutoff,
            c2c_threshold=args.c2c_threshold,
            output_c2c=args.output_c2c,
            c2c_mode=args.c2c_mode,
            limit=args.train_limit,
            ctx_num=args.ctx_num,
            images_per_identity=args.images_per_identity,
            data_extra=data_extra,
            hard_mining=hard_mining,
            triplet_params=triplet_params,
            coco_mode=coco_mode,
            mx_model=model,
            label_name=label_name,
        )

    if len(data_dir_list) == 1 and args.loss_type not in (12, 13):
        train_dataiter = _make_train_iter(path_imgrec)
    else:
        iter_list = []
        for _data_dir in data_dir_list:
            _dataiter = _make_train_iter(os.path.join(_data_dir, "train.rec"))
            iter_list.append(_dataiter)
        # NOTE(review): the last iterator is appended a second time, exactly
        # as in the original code.  This looks accidental but may be a
        # deliberate weighting of the last dataset - confirm against
        # FaceImageIterList before removing.
        iter_list.append(_dataiter)
        train_dataiter = FaceImageIterList(iter_list)

    if args.loss_type < 10:
        _metric = AccMetric()
    else:
        _metric = LossValueMetric()
    eval_metrics = [mx.metric.create(_metric)]

    # Weight initialisation depends on the first letter of the network name.
    if args.network[0] == 'r':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="out", magnitude=2)  # resnet style
    elif args.network[0] == 'i' or args.network[0] == 'x':
        initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)  # inception
    else:
        initializer = mx.init.Xavier(rnd_type='uniform', factor_type="in", magnitude=2)
    _rescale = 1.0 / args.ctx_num
    if args.noise_sgd > 0.0:
        print('use noise sgd')
        opt = NoiseSGD(scale=args.noise_sgd, learning_rate=base_lr, momentum=base_mom, wd=base_wd,
                       rescale_grad=_rescale)
    else:
        opt = optimizer.SGD(learning_rate=base_lr, momentum=base_mom, wd=base_wd, rescale_grad=_rescale)
    # Speedometer reporting interval in batches (shorter for loss types 12/13).
    som = 2 if args.loss_type in (12, 13) else 20
    _cb = mx.callback.Speedometer(args.batch_size, som)

    # Load every verification set named in args.target that exists on disk.
    ver_list = []
    ver_name_list = []
    for name in args.target.split(','):
        path = os.path.join(data_dir, name + ".bin")
        if os.path.exists(path):
            data_set = verification.load_bin(path, image_size)
            ver_list.append(data_set)
            ver_name_list.append(name)
            print('ver', name)

    def ver_test(nbatch):
        """Evaluate all verification sets; return their flip accuracies."""
        results = []
        for ver_name, ver_set in zip(ver_name_list, ver_list):
            _acc1, _std1, acc2, std2, xnorm, _embeddings = verification.test(
                ver_set, model, args.batch_size, 10, data_extra, label_shape)
            print('[%s][%d]XNorm: %f' % (ver_name, nbatch, xnorm))
            print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name, nbatch, acc2, std2))
            results.append(acc2)
        return results

    def val_test():
        """Run one pass over the validation iterator and print its accuracy."""
        val_metric = mx.metric.create(AccMetric())
        val_metric.reset()
        val_dataiter.reset()
        for eval_batch in val_dataiter:
            model.forward(eval_batch, is_train=False)
            model.update_metric(val_metric, eval_batch.label)
        acc_value = val_metric.get_name_value()[0][1]
        print('VACC: %f' % (acc_value))

    # Mutable one/two element lists let the nested callback update this state
    # without ``nonlocal``.
    highest_acc = [0.0, 0.0]  # first and last verification target
    global_step = [0]
    save_step = [0]
    if len(args.lr_steps) == 0:
        lr_steps = [40000, 60000, 80000]
        if args.loss_type >= 1 and args.loss_type <= 7:
            lr_steps = [100000, 140000, 160000]
        # Default schedules are scaled by 512 / batch_size.
        p = 512.0 / args.batch_size
        lr_steps = [int(step * p) for step in lr_steps]
    else:
        lr_steps = [int(x) for x in args.lr_steps.split(',')]
    print('lr_steps', lr_steps)

    def _batch_callback(param):
        """Per-batch hook: LR decay, logging, verification, checkpoints, beta."""
        global_step[0] += 1
        mbatch = global_step[0]
        # Each LR step is shifted by beta_freeze batches.
        for _lr in lr_steps:
            if mbatch == args.beta_freeze + _lr:
                opt.lr *= 0.1
                print('lr change to', opt.lr)
                break

        _cb(param)
        if mbatch % 1000 == 0:
            print('lr-batch-epoch:', opt.lr, param.nbatch, param.epoch)

        if mbatch >= 0 and mbatch % args.verbose == 0:
            acc_list = ver_test(mbatch)
            save_step[0] += 1
            msave = save_step[0]
            do_save = False
            if len(acc_list) > 0:
                lfw_score = acc_list[0]
                if lfw_score > highest_acc[0]:
                    highest_acc[0] = lfw_score
                    if lfw_score >= 0.998:
                        do_save = True
                if acc_list[-1] >= highest_acc[-1]:
                    highest_acc[-1] = acc_list[-1]
                    if lfw_score >= 0.99:
                        do_save = True
            # args.ckpt overrides the accuracy-based decision:
            # 0 never saves, values above 1 always save.
            if args.ckpt == 0:
                do_save = False
            elif args.ckpt > 1:
                do_save = True
            if do_save:
                print('saving', msave)
                if val_dataiter is not None:
                    val_test()
                arg, aux = model.get_params()
                mx.model.save_checkpoint(prefix, msave, model.symbol, arg, aux)
            print('[%d]Accuracy-Highest: %1.5f' % (mbatch, highest_acc[-1]))

        # beta stays constant for the first beta_freeze batches and then decays
        # polynomially down to beta_min.
        if mbatch <= args.beta_freeze:
            _beta = args.beta
        else:
            move = max(0, mbatch - args.beta_freeze)
            _beta = max(args.beta_min, args.beta * math.pow(1 + args.gamma * move, -1.0 * args.power))
        os.environ['BETA'] = str(_beta)
        if args.max_steps > 0 and mbatch > args.max_steps:
            sys.exit(0)

    epoch_cb = None

    model.fit(train_dataiter,
              begin_epoch=begin_epoch,
              num_epoch=end_epoch,
              eval_data=val_dataiter,
              eval_metric=eval_metrics,
              kvstore='device',
              optimizer=opt,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              allow_missing=True,
              batch_end_callback=_batch_callback,
              epoch_end_callback=epoch_cb)


def main():
    """Script entry point: parse the command line and launch training.

    The parsed namespace is also published as the module-level ``args``
    global before ``train_net`` is invoked.
    """
    global args
    parsed_args = parse_args()
    args = parsed_args
    train_net(parsed_args)


# Run training only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()
